
Implement factorised pointwise probabilities #1331

Open
penelopeysm wants to merge 8 commits into main from py/pointwise

Conversation

@penelopeysm
Member

@penelopeysm penelopeysm commented Mar 22, 2026

Closes #1038

  • Implement it for the base methods
  • Implement it for the chains methods
  • Tests
  • Docs
julia> using DynamicPPL, Distributions, LinearAlgebra

julia> @model function f(y)
                  x ~ MvNormal(zeros(2), I)
                  y ~ MvNormal(zeros(2), I)
              end
f (generic function with 2 methods)

julia> model = f([1.0, 1.0])
Model{typeof(f), (:y,), (), (), Tuple{Vector{Float64}}, Tuple{}, DefaultContext, false}(f, (y = [1.0, 1.0],), NamedTuple(), DefaultContext())

julia> vnt = rand(model)
VarNamedTuple
└─ x => [1.1535230266859213, -0.890066563186062]

julia> pointwise_logdensities(model, InitFromParams(vnt))
VarNamedTuple
├─ x => -2.899293996407594
└─ y => -2.8378770664093453

julia> pointwise_logdensities(model, InitFromParams(vnt); factorize=true)
VarNamedTuple
├─ x => [-1.5842462197519973, -1.3150477766555968]
└─ y => [-1.4189385332046727, -1.4189385332046727]
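The numbers in the REPL output above can be checked by hand: with identity covariance the components of `y` are independent, so each factorised entry is just a univariate standard-normal log-density, and the entries sum to the `factorize=false` scalar. A stdlib-only sketch (the `stdnormal_logpdf` helper is made up here, not part of DynamicPPL):

```julia
# Hand-check of the REPL output above (stdlib only; `stdnormal_logpdf` is a
# made-up helper). With Σ = I the components of y are independent, so each
# factorised value is a univariate N(0, 1) log-density, and the factorised
# values sum to the factorize=false scalar.
stdnormal_logpdf(z) = -0.5 * log(2π) - 0.5 * z^2

y = [1.0, 1.0]
per_element = stdnormal_logpdf.(y)  # matches the factorised entries for y
total = sum(per_element)            # matches the unfactorised scalar for y
```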

@github-actions
Contributor

github-actions Bot commented Mar 22, 2026

Benchmark Report

  • this PR's head: 4ee4da3627ff7053d0535e0fb9b8d39bcf7bec5e
  • base branch: 81a245a63bcb9192ffb4ed936e9ca6612137c5e2

Computer Information

Julia Version 1.11.9
Commit 53a02c0720c (2026-02-06 00:27 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 4 × AMD EPYC 7763 64-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver3)
Threads: 1 default, 0 interactive, 1 GC (on 4 virtual cores)

Benchmark Results

┌───────────────────────┬───────┬─────────────┬────────┬───────────────────────────────┬────────────────────────────┬─────────────────────────────────┐
│                       │       │             │        │       t(eval) / t(ref)        │     t(grad) / t(eval)      │        t(grad) / t(ref)         │
│                       │       │             │        │ ─────────┬──────────┬──────── │ ───────┬─────────┬──────── │ ──────────┬───────────┬──────── │
│                 Model │   Dim │  AD Backend │ Linked │     base │  this PR │ speedup │   base │ this PR │ speedup │      base │   this PR │ speedup │
├───────────────────────┼───────┼─────────────┼────────┼──────────┼──────────┼─────────┼────────┼─────────┼─────────┼───────────┼───────────┼─────────┤
│               Dynamic │    10 │    mooncake │   true │   268.40 │   267.17 │    1.00 │   8.49 │    8.57 │    0.99 │   2279.97 │   2290.10 │    1.00 │
│                   LDA │    12 │ reversediff │   true │  2532.77 │  2544.03 │    1.00 │   2.05 │    2.17 │    0.95 │   5185.29 │   5509.36 │    0.94 │
│   Loop univariate 10k │ 10000 │    mooncake │   true │ 31106.38 │ 29175.64 │    1.07 │   6.48 │    6.67 │    0.97 │ 201641.64 │ 194624.78 │    1.04 │
├───────────────────────┼───────┼─────────────┼────────┼──────────┼──────────┼─────────┼────────┼─────────┼─────────┼───────────┼───────────┼─────────┤
│    Loop univariate 1k │  1000 │    mooncake │   true │  3113.87 │  3010.12 │    1.03 │   6.43 │    6.41 │    1.00 │  20013.28 │  19303.27 │    1.04 │
│      Multivariate 10k │ 10000 │    mooncake │   true │ 31972.79 │ 31356.85 │    1.02 │  10.35 │    9.85 │    1.05 │ 330928.43 │ 308911.13 │    1.07 │
│       Multivariate 1k │  1000 │    mooncake │   true │  3386.55 │  3350.61 │    1.01 │   9.28 │    9.19 │    1.01 │  31423.99 │  30806.94 │    1.02 │
├───────────────────────┼───────┼─────────────┼────────┼──────────┼──────────┼─────────┼────────┼─────────┼─────────┼───────────┼───────────┼─────────┤
│ Simple assume observe │     1 │ forwarddiff │  false │     0.88 │     0.89 │    0.99 │  10.44 │    9.95 │    1.05 │      9.18 │      8.84 │    1.04 │
│           Smorgasbord │   201 │ forwarddiff │  false │   948.48 │  1234.50 │    0.77 │  74.11 │   54.78 │    1.35 │  70288.13 │  67619.86 │    1.04 │
│           Smorgasbord │   201 │      enzyme │   true │  1298.53 │  1235.22 │    1.05 │   4.86 │    4.92 │    0.99 │   6314.59 │   6077.99 │    1.04 │
├───────────────────────┼───────┼─────────────┼────────┼──────────┼──────────┼─────────┼────────┼─────────┼─────────┼───────────┼───────────┼─────────┤
│           Smorgasbord │   201 │ forwarddiff │   true │  1279.59 │  1244.87 │    1.03 │  73.53 │   69.84 │    1.05 │  94085.32 │  86937.82 │    1.08 │
│           Smorgasbord │   201 │    mooncake │   true │  1280.09 │  1255.19 │    1.02 │   4.69 │    4.70 │    1.00 │   6006.03 │   5896.23 │    1.02 │
│           Smorgasbord │   201 │ reversediff │   true │  1287.07 │  1582.48 │    0.81 │ 125.67 │  101.32 │    1.24 │ 161751.73 │ 160334.94 │    1.01 │
├───────────────────────┼───────┼─────────────┼────────┼──────────┼──────────┼─────────┼────────┼─────────┼─────────┼───────────┼───────────┼─────────┤
│              Submodel │     1 │    mooncake │   true │     0.88 │     0.84 │    1.04 │  28.31 │   26.31 │    1.08 │     24.82 │     22.23 │    1.12 │
└───────────────────────┴───────┴─────────────┴────────┴──────────┴──────────┴─────────┴────────┴─────────┴─────────┴───────────┴───────────┴─────────┘

@codecov

codecov Bot commented Mar 22, 2026

Codecov Report

❌ Patch coverage is 95.65217% with 1 line in your changes missing coverage. Please review.
✅ Project coverage is 82.35%. Comparing base (81a245a) to head (4ee4da3).

Files with missing lines                     Patch %   Lines
src/accumulators/pointwise_logdensities.jl    95.00%   1 Missing ⚠️
Additional details and impacted files
@@           Coverage Diff           @@
##             main    #1331   +/-   ##
=======================================
  Coverage   82.35%   82.35%           
=======================================
  Files          49       49           
  Lines        3502     3508    +6     
=======================================
+ Hits         2884     2889    +5     
- Misses        618      619    +1     


@sethaxen
Member

cc @sethaxen: Especially regarding PosteriorStats being a relatively heavy dep, is there any reasonable chance that this functionality could be extracted into a separate package that depends only on Distributions + basic stats stuff?

Yeah, I think we can do that. Even in testing the functionality, I had to implement much of the machinery needed to compute arbitrary conditional/marginal distributions. There have been longstanding open issues for each of these features in Distributions, and they probably ultimately belong there, but since that's not likely to happen soon, I think it could make sense to have a PartitionedDistributions.jl package that implements at least `marginal`, `conditional`, and `pointwise_conditional_logpdfs` (AKA `pointwise_conditional_loglikelihoods`). I could start by putting just the latter in such a package, register it, and then add the rest of the features later.
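For the Gaussian case, the conditionals described here have a closed form, so the pointwise machinery can be sketched in a few lines. This is NOT the PartitionedDistributions.jl implementation, just the standard Gaussian conditioning formulas with a function name borrowed from the proposal above (stdlib only):

```julia
using LinearAlgebra

# Sketch of the conditional machinery for the Gaussian case (stdlib only;
# NOT the PartitionedDistributions.jl implementation). The i-th factorised
# log-density of y ~ N(μ, Σ) is the conditional log p(y[i] | y[-i]).
function pointwise_conditional_logpdfs(μ::Vector, Σ::Matrix, y::Vector)
    map(eachindex(y)) do i
        r = setdiff(eachindex(y), i)          # indices of y[-i]
        w = Σ[i:i, r] / Σ[r, r]               # Σ[i, -i] * inv(Σ[-i, -i])
        m = μ[i] + only(w * (y[r] - μ[r]))    # conditional mean
        v = Σ[i, i] - only(w * Σ[r, i:i])     # conditional variance
        -0.5 * log(2π * v) - (y[i] - m)^2 / (2v)
    end
end

# With Σ = I this reduces to independent standard-normal log-densities:
vals = pointwise_conditional_logpdfs(zeros(2), Matrix(1.0I, 2, 2), [1.0, 1.0])
```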

@penelopeysm
Member Author

That sounds great! I'm very happy to help out in whatever way I can :)

@sethaxen
Member

sethaxen commented Mar 24, 2026

Once sethaxen/PartitionedDistributions.jl#5 is merged and there are some basic docs, I'll make a release and register it. The easier marginal/conditional implementations are already merged.

@sethaxen
Member

Okay I registered PartitionedDistributions.jl, and it'll be available once JuliaRegistries/General#153153 is merged.

@github-actions
Contributor

DynamicPPL.jl documentation for PR #1331 is available at:
https://TuringLang.github.io/DynamicPPL.jl/previews/PR1331/

Comment thread on ext/DynamicPPLMCMCChainsExt.jl (Outdated)
distributions that can be partitioned into blocks, using PartitionedDistributions.jl. For
example, if `factorize=true`, then `y ~ MvNormal(...)` will return a vector of
log-densities, one for each element of `y`. If `factorize=false`, then the log-density for
`y ~ MvNormal(...)` will be a single scalar.
Member


Might be worth noting somewhere that these factorized log-densities will only add up to the total log-density when the original density can be completely factorized into independent univariate distributions.

Member Author


Yes, that sounds good

@penelopeysm
Copy link
Copy Markdown
Member Author

I realise there aren't actually any real docs on pointwise likelihoods apart from the docstrings. This functionality is quite meaningful for users, though, so I'll add docs to the TuringLang/docs repository rather than DynamicPPL, where they'd be a bit hidden.

Comment on lines +19 to +37
If `factorize=true`, additionally attempt to provide factorised log-densities for
distributions that can be partitioned into blocks, using PartitionedDistributions.jl.

For example, if `factorize=true`, then `y ~ MvNormal(...)` will return a vector of
log-densities, one for each element of `y`. The `i`-th element of this vector will be the
conditional log-probability of `y[i]` given all the other elements of `y` (often denoted
`log p(y_{i} | y_{-i})`): in particular this is exactly the log-density required for
leave-one-out cross-validation.

In contrast, if `factorize=false`, then the log-density for `y ~ MvNormal(...)` will be a
single scalar corresponding to `logpdf(MvNormal(...), y)`.

Note that the sum of the factorised log-densities may not, in general, be equal to the
log-density of the full distribution: they will only be equal if the original distribution
can be completely factorised into independent components. For example, if `y ~ MvNormal(μ,
Σ)` where `Σ` is diagonal, then each element of `y` is independent and the sum of the
factorised log-densities will be equal to the log-density of the full distribution. In
contrast, if `Σ` has off-diagonal entries, then the elements of `y` are not independent.
"""
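The caveat in the final paragraph of the docstring can be demonstrated numerically. The sketch below (stdlib only; both helper names are made up for this illustration) shows that for a correlated bivariate Gaussian the conditional log-densities do not sum to the joint log-density, while for a diagonal covariance they do:

```julia
using LinearAlgebra

# Numerical illustration of the docstring's caveat (stdlib only; both helpers
# are made up for this sketch). For a correlated bivariate N(0, Σ), the sum
# log p(y[1] | y[2]) + log p(y[2] | y[1]) is NOT the joint log-density;
# with a diagonal Σ (i.e. ρ = 0) the two agree.
normal_logpdf(x, m, v) = -0.5 * log(2π * v) - (x - m)^2 / (2v)
mvnormal_logpdf(Σ, y) =
    -0.5 * (length(y) * log(2π) + logdet(Σ) + dot(y, Σ \ y))

y = [1.0, -0.5]
ρ = 0.8
Σ = [1.0 ρ; ρ 1.0]
# Conditionals of a bivariate N(0, Σ): y[1] | y[2] ~ N(ρ * y[2], 1 - ρ^2), etc.
conds = [normal_logpdf(y[1], ρ * y[2], 1 - ρ^2),
         normal_logpdf(y[2], ρ * y[1], 1 - ρ^2)]
sum(conds) ≈ mvnormal_logpdf(Σ, y)  # false: the components are correlated
# With ρ = 0 the distribution factorises and the sums agree:
sum(normal_logpdf.(y, 0.0, 1.0)) ≈ mvnormal_logpdf(Matrix(1.0I, 2, 2), y)
```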
Member Author


@sethaxen Would you be willing to sense-check my docstring?

Member Author


(If you want to review generally please feel free to as well, but it's also not your job so please don't feel obliged 🙂)

@penelopeysm penelopeysm marked this pull request as ready for review April 23, 2026 22:33


Development

Successfully merging this pull request may close these issues.

Computing pointwise log-likelihoods without factorizing the likelihood

2 participants